import os
import random
import numpy as np
from taichi_file_utils.files import find_files, get_fps_from_filename, get_data_from_file, RotateClass


class DataSet(object):
    """Random sampler of (current frame, next frame) training samples.

    The folder is scanned once with ``find_files``. Each call to
    :meth:`get_batch` picks one of those files at random, pairs it with the
    file named ``<n + 1>.csv`` in the same directory (``n`` being whatever
    ``get_fps_from_filename`` returns for the picked file - presumably a frame
    index, TODO confirm) and merges both tables into one sample array.
    """

    def __init__(self, dataset_folder, is_target_vel=False, dt=None, is_random_rotate=True):
        """Scan ``dataset_folder`` and remember the sampling options.

        Args:
            dataset_folder: Folder handed to ``find_files``.
            is_target_vel: If true the target columns hold the next-frame
                velocity; otherwise they hold ``(next_vel - vel) / dt``.
            dt: Time step used for the acceleration target. Required when
                ``is_target_vel`` is false.
            is_random_rotate: If true every sample is passed through a fresh
                ``RotateClass`` instance before being returned.
        """
        self.dataset_folder = dataset_folder
        self.files = find_files(dataset_folder)
        self.is_target_vel = is_target_vel
        self.dt = dt
        self.is_random_rotate = is_random_rotate

    def get_batch(self, batch=1):
        """Return one randomly chosen sample.

        Args:
            batch: Number of samples requested; only ``1`` is implemented.

        Returns:
            The array built by :meth:`_build_data`, or ``None`` when the phase
            column differs between the two frames.

        Raises:
            NotImplementedError: If ``batch`` is not 1.
            IndexError: If the folder scan found no files (raised by
                ``random.choice``).
            FileNotFoundError: If none of the known files has a successor
                file next to it.
        """
        if batch != 1:
            # NotImplementedError is still an Exception, so existing handlers keep working.
            raise NotImplementedError('no code for batch != 1')  # todo no code for batch != 1

        # Files that might still have a successor. Built lazily on the first
        # miss so the common case stays a single random.choice.
        untried = None
        while True:
            select_file = random.choice(self.files)
            print(select_file)
            self.fps = get_fps_from_filename(select_file)
            label_file = os.path.join(os.path.dirname(select_file), str(self.fps + 1) + '.csv')
            if os.path.exists(label_file):
                break

            if untried is None:
                untried = set(self.files)
            untried.discard(select_file)
            if not untried:
                # Every file has been rejected: fail instead of spinning forever.
                raise FileNotFoundError(
                    'no file in %r has a following frame file' % (self.dataset_folder,))

        self.current_data = self._build_data(get_data_from_file(select_file), get_data_from_file(label_file),
                                             target_vel=self.is_target_vel, dt=self.dt,
                                             random_rotate=self.is_random_rotate)
        return self.current_data

    def _build_data(self, table_data, label_data, dt=None, target_vel=False, random_rotate=False):
        """Merge the current and next frame tables into one sample array.

        Columns 0-2 of each table are read as position, 3-5 as velocity and
        6 as phase. Neither input array is modified.

        Args:
            table_data: 2-D array for the current frame.
            label_data: 2-D array for the next frame.
            dt: Time step; required when ``target_vel`` is false.
            target_vel: If true the target is the next-frame velocity,
                otherwise ``(next_vel - vel) / dt``.
            random_rotate: If true, columns 0 and 2 of position, velocity and
                target are replaced by the output of one ``RotateClass``
                instance (presumably a rotation about the middle axis -
                TODO confirm against ``RotateClass``).

        Returns:
            ``[pos, vel, phase, target]`` concatenated along axis 1, or
            ``None`` if the phase columns of the two frames differ.

        Raises:
            ValueError: If ``target_vel`` is false and ``dt`` is ``None``.
        """
        pos, vel, phase = table_data[:, :3], table_data[:, 3:6], table_data[:, 6:7]
        l_vel, l_phase = label_data[:, 3:6], label_data[:, 6:7]

        if not np.all(np.equal(phase, l_phase)):
            print("change between two fps!")
            return None

        if target_vel:
            out = l_vel
        else:
            if dt is None:
                raise ValueError('dt is required when target_vel is False')
            # Out-of-place arithmetic: l_vel is a view, so `-=` / `/=` would
            # overwrite label_data (and fail outright on integer arrays).
            out = (l_vel - vel) / dt

        if random_rotate:
            # Work on copies so the writes below never reach the input arrays.
            pos, vel, out = pos.copy(), vel.copy(), out.copy()
            rotate = RotateClass()
            pos[:, 0], pos[:, 2] = rotate(pos[:, 0], pos[:, 2])
            vel[:, 0], vel[:, 2] = rotate(vel[:, 0], vel[:, 2])
            out[:, 0], out[:, 2] = rotate(out[:, 0], out[:, 2])

        return np.concatenate([pos, vel, phase, out], axis=1)


if __name__ == '__main__':
    # Smoke test: load a single sample from the folder given on the command line.
    import sys

    if len(sys.argv) < 2:
        sys.exit('usage: python %s <dataset_folder> [dt]' % sys.argv[0])

    # DataSet needs a folder; the original call passed none and always raised TypeError.
    # Acceleration targets need a time step, so without one fall back to velocity targets.
    dt = float(sys.argv[2]) if len(sys.argv) > 2 else None
    dataset = DataSet(sys.argv[1], is_target_vel=dt is None, dt=dt)
    data = dataset.get_batch()
